"""
Phase 1: Data Assembly for Japanification Project
==================================================
Build japanification panel from multilateral project data.

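Sources: multilateral/followup/data/processed/full_panel.csv, supplemented,
         when the panel lacks those columns, by multilateral/data/processed/
         dependency_ratios.csv and future_oadr.csv
Outputs: japanification/data/processed/japan_panel.csv
         japanification/output/tables/phase1_summary_stats.csv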
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Filesystem layout: every input and output hangs off one project root, with
# the japanification outputs living beside the multilateral inputs.
PROJECT_DIR = Path("/mnt/c/demographics_capital_flows")
MULTILATERAL_DIR = PROJECT_DIR.joinpath("multilateral")
JAPAN_DIR = PROJECT_DIR.joinpath("japanification")
PROCESSED_DIR = JAPAN_DIR.joinpath("data", "processed")
TABLE_DIR = JAPAN_DIR.joinpath("output", "tables")

# Create both output directories at import time so main() can write
# its CSVs without any further existence checks.
for _output_dir in (PROCESSED_DIR, TABLE_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)
del _output_dir


def _add_inflation(df):
    """Add ``inflation_japan``: WEO ``inflation``, falling back to ``cpi_inflation_wb``.

    Mutates *df* in place and prints how many observations come from each
    source.
    """
    wb_fill = df['inflation'].isna() & df['cpi_inflation_wb'].notna()
    # fillna with a Series aligns on the index, so each row takes the WB value
    # only where the WEO value is missing.
    df['inflation_japan'] = df['inflation'].fillna(df['cpi_inflation_wb'])
    print(f"Inflation: WEO={df['inflation'].notna().sum():,}, "
          f"WB fallback filled={wb_fill.sum():,}, "
          f"total={df['inflation_japan'].notna().sum():,}")


def _add_rate(df, sources=('govt_bond_10y', 'policy_rate', 'lending_rate')):
    """Add ``rate_japan`` / ``rate_source`` from the first available rate in *sources*.

    Sources are tried in priority order, and each one only fills rows that are
    still missing a rate. ``rate_source`` records which column was used, or ''
    if none was available. All rates are copied untransformed, including
    lending rates. NOTE(review): if extreme lending rates need taming (e.g. a
    log transform), that is not done here; confirm it is intended.

    Mutates *df* in place and prints the count per source and the number still
    missing.
    """
    df['rate_japan'] = np.nan
    df['rate_source'] = ''
    for src in sources:
        fill = df['rate_japan'].isna() & df[src].notna()
        df.loc[fill, 'rate_japan'] = df.loc[fill, src]
        df.loc[fill, 'rate_source'] = src

    print("\nRate variable hierarchy:")
    for src, cnt in df['rate_source'].value_counts().items():
        if src:
            print(f"  {src}: {cnt:,}")
    print(f"  missing: {(df['rate_source'] == '').sum():,}")


def _add_dependency_ratios(df):
    """Ensure all dependency-ratio columns exist, merging only the missing ones.

    Columns already present in *df* (from the upstream full panel) are kept
    as-is. Absent ones are left-merged from ``dependency_ratios.csv`` on
    (iso3, year). Returns the (possibly new) DataFrame.
    """
    keys = ['iso3', 'year']
    value_cols = ['youth_dep', 'old_dep', 'total_dep', 'working_age_share']
    present = [c for c in value_cols if c in df.columns]
    missing = [c for c in value_cols if c not in df.columns]
    if present:
        print(f"\nDependency ratios already in panel: {present}")
    if missing:
        dep = pd.read_csv(MULTILATERAL_DIR / "data" / "processed" / "dependency_ratios.csv")
        # validate: duplicate (iso3, year) keys in the lookup table would
        # otherwise silently duplicate panel rows.
        df = df.merge(dep[keys + missing], on=keys, how='left',
                      validate='many_to_one')
        print(f"\nMerged dependency ratios {missing}: "
              f"{df['old_dep'].notna().sum():,} non-missing")
    return df


def _add_future_oadr(df):
    """Ensure ``oadr_plus20`` exists, left-merging it from ``future_oadr.csv`` if absent."""
    if 'oadr_plus20' not in df.columns:
        future = pd.read_csv(MULTILATERAL_DIR / "data" / "processed" / "future_oadr.csv")
        df = df.merge(future[['iso3', 'year', 'oadr_plus20']], on=['iso3', 'year'],
                      how='left', validate='many_to_one')
    print(f"Future OADR available: {df['oadr_plus20'].notna().sum():,} obs")
    return df


def _add_delta_old_dep(df):
    """Return *df* sorted by (iso3, year) with ``delta_old_dep``, the 1-year change in ``old_dep``.

    The core-variable filter can leave holes in a country's years, so a plain
    within-country ``diff()`` would report multi-year changes as annual ones.
    The change is therefore NaN unless the previous row is exactly the
    preceding year.
    """
    df = df.sort_values(['iso3', 'year'])
    by_country = df.groupby('iso3')
    delta = by_country['old_dep'].diff()
    consecutive = by_country['year'].diff().eq(1)
    df['delta_old_dep'] = delta.where(consecutive)
    return df


def _print_summary_stats(panel):
    """Print descriptive statistics for the key variables and return the table.

    Variables missing from *panel* are skipped rather than raising.
    """
    print(f"\n{'=' * 70}")
    print("SUMMARY STATISTICS")
    print("=" * 70)
    summary_vars = [v for v in ('rgdp_growth', 'inflation_japan', 'rate_japan',
                                'old_dep', 'Z_1', 'Z_2', 'Z_3', 'kaopen')
                    if v in panel.columns]
    summary = panel[summary_vars].describe().T
    summary['non_missing'] = panel[summary_vars].notna().sum()
    print(summary[['non_missing', 'mean', 'std', 'min', '25%', '50%', '75%', 'max']]
          .to_string(float_format='%.3f'))
    return summary


def main():
    """Assemble the japanification panel and write it with summary statistics.

    Steps:
    1. Load the multilateral full panel and keep 1990-2024.
    2. Build the inflation series (WEO with WB fallback).
    3. Keep rows with growth, inflation and the Z_1..Z_3 demographic polynomials.
    4. Build the interest-rate series from a source hierarchy.
    5. Attach dependency ratios and future OADR where missing.
    6. Winsorize inflation at the pooled 1st/99th percentiles.
    7. Add the annual change in the old-age dependency ratio.
    8. Save ``japan_panel.csv`` and ``phase1_summary_stats.csv``.

    Returns:
        pandas.DataFrame: the saved panel, restricted to the output columns
        that actually exist.
    """
    print("=" * 70)
    print("PHASE 1: Data Assembly for Japanification Project")
    print("=" * 70)

    # 1. Load full_panel.csv (it lives under multilateral/followup/)
    fp = pd.read_csv(MULTILATERAL_DIR / "followup" / "data" / "processed" / "full_panel.csv")
    print(f"\nLoaded full_panel: {fp.shape[0]:,} rows, {fp['iso3'].nunique()} countries, "
          f"{fp['year'].min()}-{fp['year'].max()}")

    # 2. Filter to 1990-2024 (inclusive)
    df = fp[(fp['year'] >= 1990) & (fp['year'] <= 2024)].copy()
    print(f"After 1990-2024 filter: {len(df):,} rows, {df['iso3'].nunique()} countries")

    # 3. Inflation, then require growth + inflation + Z polynomials
    _add_inflation(df)
    core_mask = (
        df['rgdp_growth'].notna() &
        df['inflation_japan'].notna() &
        df['Z_1'].notna() & df['Z_2'].notna() & df['Z_3'].notna()
    )
    df = df[core_mask].copy()
    print(f"After core variable filter: {len(df):,} rows, {df['iso3'].nunique()} countries")

    # 4. Rate hierarchy: govt_bond_10y > policy_rate > lending_rate
    _add_rate(df)

    # 5-6. Demographic inputs possibly not carried by full_panel
    df = _add_dependency_ratios(df)
    df = _add_future_oadr(df)

    # 7. Winsorize inflation at the pooled p1/p99 (core filter guarantees non-missing)
    p1 = df['inflation_japan'].quantile(0.01)
    p99 = df['inflation_japan'].quantile(0.99)
    n_winsorized = ((df['inflation_japan'] < p1) | (df['inflation_japan'] > p99)).sum()
    df['inflation_japan'] = df['inflation_japan'].clip(lower=p1, upper=p99)
    print(f"\nInflation winsorized at [{p1:.1f}, {p99:.1f}]: {n_winsorized} obs clipped")

    # 8. Speed of aging (annual change in old-age dependency)
    df = _add_delta_old_dep(df)

    # 9. Select output columns
    out_cols = [
        'iso3', 'year',
        # Japanification components
        'rgdp_growth', 'inflation_japan', 'rate_japan', 'rate_source',
        # Demographics - polynomials
        'Z_1', 'Z_2', 'Z_3',
        # Demographics - dependency ratios
        'youth_dep', 'old_dep', 'total_dep', 'working_age_share',
        'delta_old_dep', 'oadr_plus20',
        # Demographics - other
        'life_expectancy',
        # Controls
        'fiscal_bal_gdp', 'kaopen', 'log_rel_opw', 'nfa_gdp_lag',
        'expected_growth', 'output_gap',
        # Income
        'gdp_pc_ppp',
        # Rate details (for decomposition)
        'govt_bond_10y', 'policy_rate', 'lending_rate',
        'real_bond_10y', 'real_bond_10y_diff',
        # CA/GDP for cross-validation
        'ca_gdp',
        # Interactions (pre-computed)
        'Z_1_x_kaopen', 'Z_2_x_kaopen', 'Z_3_x_kaopen',
        # Age bins for coefficient recovery
    ] + [f'd_n_{i}' for i in range(1, 18)]
    # Keep only columns that exist, de-duplicated in first-seen order.
    out_cols = list(dict.fromkeys(c for c in out_cols if c in df.columns))
    panel = df[out_cols].copy()

    # 10. Save
    out_path = PROCESSED_DIR / "japan_panel.csv"
    panel.to_csv(out_path, index=False)
    print(f"\n{'=' * 70}")
    print(f"Saved: {out_path}")
    print(f"  {len(panel):,} obs, {panel['iso3'].nunique()} countries, "
          f"{panel['year'].min()}-{panel['year'].max()}")
    has_rate = panel['rate_japan'].notna()
    print(f"  With any rate: {has_rate.sum():,} obs "
          f"({panel.loc[has_rate, 'iso3'].nunique()} countries)")
    print(f"  Core (growth + inflation + Z): {len(panel):,}")

    summary = _print_summary_stats(panel)
    summary.to_csv(TABLE_DIR / "phase1_summary_stats.csv")

    return panel


if __name__ == "__main__":
    # Script entry point: build and save the panel. The returned DataFrame is
    # also bound to module-level `panel` (useful when run with `python -i`).
    panel = main()
